{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "N6ZDpd9XzFeN"
      },
      "source": [
        "##### Copyright 2018 The TensorFlow Hub Authors.\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "both",
        "id": "KUu4vOt5zI9d"
      },
      "outputs": [],
      "source": [
        "# Copyright 2018 The TensorFlow Hub Authors. All Rights Reserved.\n",
        "#\n",
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "#     http://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License.\n",
        "# =============================================================================="
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CxmDMK4yupqg"
      },
      "source": [
        "# Generate Artificial Faces with CelebA Progressive GAN Model\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MfBg1C5NB3X0"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/hub/tutorials/tf_hub_generative_image_module\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/hub/tutorials/tf_hub_generative_image_module.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/hub/tutorials/tf_hub_generative_image_module.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View on GitHub</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/docs/site/en/hub/tutorials/tf_hub_generative_image_module.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://tfhub.dev/google/progan-128/1\"><img src=\"https://www.tensorflow.org/images/hub_logo_32px.png\" />See TF Hub model</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Sy553YSVmYiK"
      },
      "source": [
        "This Colab demonstrates use of a TF Hub module based on a generative adversarial network (GAN). The module maps from N-dimensional vectors, called latent space, to RGB images.\n",
        "\n",
        "Two examples are provided:\n",
        "* **Mapping** from latent space to images, and\n",
        "* Given a target image, **using gradient descent to find** a latent vector that generates an image similar to the target image."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "v4XGxDrCkeip"
      },
      "source": [
        "## Optional prerequisites\n",
        "\n",
        "* Familiarity with [low level Tensorflow concepts](https://www.tensorflow.org/guide/eager).\n",
        "* [Generative Adversarial Network](https://en.wikipedia.org/wiki/Generative_adversarial_network) on Wikipedia.\n",
        "* Paper on Progressive GANs: [Progressive Growing of GANs for Improved Quality, Stability, and Variation](https://arxiv.org/abs/1710.10196)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HK3Q2vIaVw56"
      },
      "source": [
        "### More models\n",
        "[Here](https://tfhub.dev/s?module-type=image-generator) you can find all models currently hosted on [tfhub.dev](https://tfhub.dev/) that can generate images."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Q4DN769E2O_R"
      },
      "source": [
        "## Setup"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KNM3kA0arrUu"
      },
      "outputs": [],
      "source": [
        "# Install imageio for creating animations.  \n",
        "!pip -q install imageio\n",
        "!pip -q install scikit-image\n",
        "!pip install git+https://github.com/tensorflow/docs"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "both",
        "id": "6cPY9Ou4sWs_"
      },
      "outputs": [],
      "source": [
        "#@title Imports and function definitions\n",
        "from absl import logging\n",
        "\n",
        "import imageio\n",
        "import PIL.Image\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "\n",
        "import tensorflow as tf\n",
        "tf.random.set_seed(0)\n",
        "\n",
        "import tensorflow_hub as hub\n",
        "from tensorflow_docs.vis import embed\n",
        "import time\n",
        "\n",
        "try:\n",
        "  from google.colab import files\n",
        "except ImportError:\n",
        "  pass\n",
        "\n",
        "from IPython import display\n",
        "from skimage import transform\n",
        "\n",
        "# We could retrieve this value from module.get_input_shapes() if we didn't know\n",
        "# beforehand which module we will be using.\n",
        "latent_dim = 512\n",
        "\n",
        "\n",
        "# Interpolates between two vectors that are non-zero and don't both lie on a\n",
        "# line going through origin. First normalizes v2 to have the same norm as v1. \n",
        "# Then interpolates between the two vectors on the hypersphere.\n",
        "def interpolate_hypersphere(v1, v2, num_steps):\n",
        "  v1_norm = tf.norm(v1)\n",
        "  v2_norm = tf.norm(v2)\n",
        "  v2_normalized = v2 * (v1_norm / v2_norm)\n",
        "\n",
        "  vectors = []\n",
        "  for step in range(num_steps):\n",
        "    interpolated = v1 + (v2_normalized - v1) * step / (num_steps - 1)\n",
        "    interpolated_norm = tf.norm(interpolated)\n",
        "    interpolated_normalized = interpolated * (v1_norm / interpolated_norm)\n",
        "    vectors.append(interpolated_normalized)\n",
        "  return tf.stack(vectors)\n",
        "\n",
        "# Simple way to display an image.\n",
        "def display_image(image):\n",
        "  image = tf.constant(image)\n",
        "  image = tf.image.convert_image_dtype(image, tf.uint8)\n",
        "  return PIL.Image.fromarray(image.numpy())\n",
        "\n",
        "# Given a set of images, show an animation.\n",
        "def animate(images):\n",
        "  images = np.array(images)\n",
        "  converted_images = np.clip(images * 255, 0, 255).astype(np.uint8)\n",
        "  imageio.mimsave('./animation.gif', converted_images)\n",
        "  return embed.embed_file('./animation.gif')\n",
        "\n",
        "logging.set_verbosity(logging.ERROR)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "f5EESfBvukYI"
      },
      "source": [
        "## Latent space interpolation"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nJb9gFmRvynZ"
      },
      "source": [
        "### Random vectors\n",
        "\n",
        "Latent space interpolation between two randomly initialized vectors. We will use a TF Hub module [progan-128](https://tfhub.dev/google/progan-128/1) that contains a pre-trained Progressive GAN."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8StEe9x9wGma"
      },
      "outputs": [],
      "source": [
        "progan = hub.load(\"https://tfhub.dev/google/progan-128/1\").signatures['default']"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fZ0O5_5Jhwio"
      },
      "outputs": [],
      "source": [
        "def interpolate_between_vectors():\n",
        "  v1 = tf.random.normal([latent_dim])\n",
        "  v2 = tf.random.normal([latent_dim])\n",
        "    \n",
        "  # Creates a tensor with 25 steps of interpolation between v1 and v2.\n",
        "  vectors = interpolate_hypersphere(v1, v2, 50)\n",
        "\n",
        "  # Uses module to generate images from the latent space.\n",
        "  interpolated_images = progan(vectors)['default']\n",
        "\n",
        "  return interpolated_images\n",
        "\n",
        "interpolated_images = interpolate_between_vectors()\n",
        "animate(interpolated_images)"
      ]
    },
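    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a quick sanity check on the interpolation used above, the following sketch verifies that every point on the path keeps the norm of the first vector. It assumes the setup cell has been run, so that `interpolate_hypersphere` and `latent_dim` are defined. The `check_`-prefixed names, the step count, and the NumPy random generator are illustrative choices made here; NumPy is used so that TensorFlow's global random state (and therefore the images generated elsewhere in this notebook) is left untouched."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Illustrative sanity check, not part of the original tutorial flow.\n",
        "import numpy as np\n",
        "import tensorflow as tf\n",
        "\n",
        "# Use a NumPy generator so TensorFlow's global seed is not consumed.\n",
        "rng = np.random.default_rng(0)\n",
        "check_v1 = tf.constant(rng.standard_normal(latent_dim).astype(np.float32))\n",
        "check_v2 = tf.constant(rng.standard_normal(latent_dim).astype(np.float32))\n",
        "\n",
        "# Five interpolation steps are enough to inspect the norms.\n",
        "check_vectors = interpolate_hypersphere(check_v1, check_v2, 5)\n",
        "path_norms = tf.norm(check_vectors, axis=1).numpy()\n",
        "\n",
        "print('norm of first vector:', tf.norm(check_v1).numpy())\n",
        "print('norms along the path:', path_norms)\n",
        "\n",
        "# Every interpolated vector should have (numerically) the same norm.\n",
        "np.testing.assert_allclose(path_norms, tf.norm(check_v1).numpy(), rtol=1e-4)"
      ]
    },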
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "L9-uXoTHuXQC"
      },
      "source": [
        "## Finding closest vector in latent space\n",
        "Fix a target image. As an example use an image generated from the module or upload your own."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "both",
        "id": "phT4W66pMmko"
      },
      "outputs": [],
      "source": [
        "image_from_module_space = True  # @param { isTemplate:true, type:\"boolean\" }\n",
        "\n",
        "def get_module_space_image():\n",
        "  vector = tf.random.normal([1, latent_dim])\n",
        "  images = progan(vector)['default'][0]\n",
        "  return images\n",
        "\n",
        "def upload_image():\n",
        "  uploaded = files.upload()\n",
        "  image = imageio.imread(uploaded[list(uploaded.keys())[0]])\n",
        "  return transform.resize(image, [128, 128])\n",
        "\n",
        "if image_from_module_space:\n",
        "  target_image = get_module_space_image()\n",
        "else:\n",
        "  target_image = upload_image()\n",
        "\n",
        "display_image(target_image)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rBIt3Q4qvhuq"
      },
      "source": [
        "After defining a loss function between the target image and the image generated by a latent space variable, we can use gradient descent to find variable values that minimize the loss."
      ]
    },
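    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Written out, the objective minimized by the optimization cell further below has the following form. The symbols are notation introduced here for illustration: $G$ is the generator (`progan`), $z$ the latent vector, $x$ the target image, $N$ the latent dimension (`latent_dim`), and $H \\times W \\times C$ the image shape.\n",
        "\n",
        "$$L(z) = \\sum_{h=1}^{H} \\sum_{w=1}^{W} \\frac{1}{C} \\sum_{c=1}^{C} \\left| G(z)_{hwc} - x_{hwc} \\right| + \\left| \\lVert z \\rVert_2 - \\sqrt{N} \\right|$$\n",
        "\n",
        "The first term corresponds to `tf.losses.MeanAbsoluteError(reduction=\"sum\")` applied to a single image (a mean over the channel axis, summed over pixels); the second term is the regularizer on the length of the latent vector."
      ]
    },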
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "cUGakLdbML2Q"
      },
      "outputs": [],
      "source": [
        "tf.random.set_seed(42)\n",
        "initial_vector = tf.random.normal([1, latent_dim])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "u7MGzDE5MU20"
      },
      "outputs": [],
      "source": [
        "display_image(progan(initial_vector)['default'][0])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "q_4Z7tnyg-ZY"
      },
      "outputs": [],
      "source": [
        "def find_closest_latent_vector(initial_vector, num_optimization_steps,\n",
        "                               steps_per_image):\n",
        "  images = []\n",
        "  losses = []\n",
        "\n",
        "  vector = tf.Variable(initial_vector)  \n",
        "  optimizer = tf.optimizers.Adam(learning_rate=0.01)\n",
        "  loss_fn = tf.losses.MeanAbsoluteError(reduction=\"sum\")\n",
        "\n",
        "  for step in range(num_optimization_steps):\n",
        "    if (step % 100)==0:\n",
        "      print()\n",
        "    print('.', end='')\n",
        "    with tf.GradientTape() as tape:\n",
        "      image = progan(vector.read_value())['default'][0]\n",
        "      if (step % steps_per_image) == 0:\n",
        "        images.append(image.numpy())\n",
        "      target_image_difference = loss_fn(image, target_image[:,:,:3])\n",
        "      # The latent vectors were sampled from a normal distribution. We can get\n",
        "      # more realistic images if we regularize the length of the latent vector to \n",
        "      # the average length of vector from this distribution.\n",
        "      regularizer = tf.abs(tf.norm(vector) - np.sqrt(latent_dim))\n",
        "      \n",
        "      loss = target_image_difference + regularizer\n",
        "      losses.append(loss.numpy())\n",
        "    grads = tape.gradient(loss, [vector])\n",
        "    optimizer.apply_gradients(zip(grads, [vector]))\n",
        "    \n",
        "  return images, losses\n",
        "\n",
        "\n",
        "num_optimization_steps=200\n",
        "steps_per_image=5\n",
        "images, loss = find_closest_latent_vector(initial_vector, num_optimization_steps, steps_per_image)"
      ]
    },
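    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The regularizer above relies on the fact that vectors sampled from a standard normal distribution in `latent_dim` dimensions have a length close to `sqrt(latent_dim)`. The following sketch checks this empirically with NumPy alone; the sample count and the names `dim`, `rng`, and `sample_norms` are arbitrary choices made for this illustration."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Illustrative check: typical norm of a standard normal vector.\n",
        "import numpy as np\n",
        "\n",
        "dim = 512  # Same value as latent_dim in this notebook.\n",
        "rng = np.random.default_rng(0)\n",
        "\n",
        "# Draw many latent-like vectors and measure their Euclidean lengths.\n",
        "sample_norms = np.linalg.norm(rng.standard_normal((10000, dim)), axis=1)\n",
        "\n",
        "print('mean norm:', sample_norms.mean())\n",
        "print('std of norm:', sample_norms.std())\n",
        "print('sqrt(dim):', np.sqrt(dim))"
      ]
    },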
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "pRbeF2oSAcOB"
      },
      "outputs": [],
      "source": [
        "plt.plot(loss)\n",
        "plt.ylim([0,max(plt.ylim())])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KnZkDy2FEsTt"
      },
      "outputs": [],
      "source": [
        "animate(np.stack(images))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GGKfuCdfPQKH"
      },
      "source": [
        "Compare the result to the target:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TK1P5z3bNuIl"
      },
      "outputs": [],
      "source": [
        "display_image(np.concatenate([images[-1], target_image], axis=1))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tDt15dLsJwMy"
      },
      "source": [
        "### Playing with the above example\n",
        "If image is from the module space, the descent is quick and converges to a reasonable sample. Try out descending to an image that is **not from the module space**. The descent will only converge if the image is reasonably close to the space of training images.\n",
        "\n",
        "How to make it descend faster and to a more realistic image? One can try:\n",
        "* using different loss on the image difference, e.g., quadratic,\n",
        "* using different regularizer on the latent vector,\n",
        "* initializing from a random vector in multiple runs,\n",
        "* etc.\n"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [
        "N6ZDpd9XzFeN"
      ],
      "name": "tf_hub_generative_image_module.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
